"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any, Dict, List, Tuple, Union

import numpy as np
from paddleformers.transformers import AutoTokenizer

from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.utils import data_processor_logger

from .image_processor import ImageProcessor
from .process_video import read_frames, sample_frames


class DataProcessor:
    """
    Processes multimodal inputs (text, images, videos) into model-ready formats.

    Handles:
    - Tokenization of text with special tokens for visual content
    - Image and video preprocessing
    - Generation of 3D positional embeddings
    - Conversion of chat messages to model inputs

    Attributes:
        tokenizer: Text tokenizer instance
        image_processor: Image/video preprocessor
        image_token: Special token for image placeholders
        video_token: Special token for video placeholders
        vision_start: Token marking start of visual content
    """

    def __init__(
        self,
        model_path: str,
        video_min_frames: int = 4,
        video_max_frames: int = 768,
        tokens_per_second: int = 2,
        tokenizer=None,
        **kwargs,
    ) -> None:
        """
        Set up the tokenizer, the image processor and the special-token lookups.

        Args:
            model_path: Location handed to ``from_pretrained`` for the image
                processor, and for the tokenizer when none is supplied
            video_min_frames: Default lower bound on sampled video frames
            video_max_frames: Default upper bound on sampled video frames
            tokens_per_second: Multiplier applied to temporal position IDs
            tokenizer: Optional ready-made tokenizer; when omitted, one is
                loaded from ``model_path``
            **kwargs: Accepted for call compatibility; not read here
        """
        # Frame-count bounds used as defaults when a video item does not override them
        self.min_frames, self.max_frames = video_min_frames, video_max_frames
        self.tokens_per_second = tokens_per_second

        if tokenizer is not None:
            # A caller-provided tokenizer is stored untouched
            self.tokenizer = tokenizer
        else:
            # Load a fast, left-padding tokenizer and tag it with the ignored label index
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True)
            self.tokenizer.ignored_index = -100

        # The image processor also supplies the patch-merging factors used for position IDs
        self.image_processor = ImageProcessor.from_pretrained(model_path)
        self.spatial_conv_size = self.image_processor.merge_size
        self.temporal_conv_size = self.image_processor.temporal_patch_size

        # Special token strings
        self.image_token = "<|image_pad|>"
        self.video_token = "<|video_pad|>"
        self.vision_start = "<|vision_start|>"

        # Matching vocabulary IDs
        to_id = self.tokenizer.convert_tokens_to_ids
        self.image_token_id = to_id(self.image_token)
        self.video_token_id = to_id(self.video_token)
        self.vision_start_id = to_id(self.vision_start)

        # Roles accepted in chat messages, with their text prefixes
        self.role_prefixes = dict(
            system="",
            user="User: ",
            bot="Assistant: ",
            assistant="Assistant: ",
        )

    def _pack_outputs(self, outputs):
        """
        Turn the accumulated Python lists in ``outputs`` into numpy arrays, in place.

        Args:
            outputs (dict): Accumulator holding at least the keys
                ``images``, ``grid_thw``, ``image_type_ids``, ``input_ids``,
                ``token_type_ids`` and ``position_ids``

        Returns:
            dict: The same dictionary object. The three visual entries are set
            to None when ``images`` is empty; otherwise ``images`` and
            ``grid_thw`` are vertically stacked and ``image_type_ids`` becomes
            an array. ``input_ids`` and ``token_type_ids`` become int64 arrays
            and ``position_ids`` is concatenated along axis 1 as int64.
        """
        if outputs["images"]:
            # At least one image/video was added: stack per-item entries row-wise
            outputs["images"] = np.vstack(outputs["images"])
            outputs["grid_thw"] = np.vstack(outputs["grid_thw"])
            outputs["image_type_ids"] = np.array(outputs["image_type_ids"])
        else:
            # Text-only input: mark every visual field as absent
            for visual_key in ("images", "grid_thw", "image_type_ids"):
                outputs[visual_key] = None

        # Flat per-token sequences
        for sequence_key in ("input_ids", "token_type_ids"):
            outputs[sequence_key] = np.array(outputs[sequence_key], dtype=np.int64)

        # Per-segment position blocks are joined along the token axis
        outputs["position_ids"] = np.concatenate(outputs["position_ids"], axis=1, dtype=np.int64)
        return outputs

    def text2ids(self, text, images=None, videos=None):
        """
        Convert text containing visual markers into model inputs.

        The text is scanned left to right for the literal markers
        ``<|image_pad|>`` and ``<|video_pad|>``. Text between markers is
        tokenized; each marker consumes the next entry of ``images`` or
        ``videos`` respectively.

        Args:
            text: Input string, possibly containing the markers above
            images: Sequence indexed in order, one entry per image marker;
                each entry is handed to ``_add_image``
            videos: Sequence indexed in order, one entry per video marker.
                An entry is either a dict (its ``"video"`` value is the
                source and the dict itself supplies sampling options) or the
                video source directly

        Returns:
            Dict produced by ``_pack_outputs``, containing:
                - input_ids: Token IDs
                - token_type_ids: Type identifiers (text/image/video)
                - position_ids: 3D position IDs
                - images: Preprocessed visual features (None if no visuals)
                - grid_thw: Spatial/temporal dimensions (None if no visuals)
                - image_type_ids: Visual content type (0=image, 1=video)
            plus the bookkeeping entries ``labels``, ``cur_position``,
            ``pic_cnt`` and ``video_cnt``.
        """
        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        image_marker = "<|image_pad|>"
        video_marker = "<|video_pad|>"
        text_len = len(text)

        cursor = 0  # Index of the first character not yet consumed
        next_image = 0  # Index into `images`
        next_video = 0  # Index into `videos`

        while cursor < text_len:
            # Locate the nearest marker of each kind; a miss counts as "end of text"
            image_at = text.find(image_marker, cursor)
            if image_at == -1:
                image_at = text_len
            video_at = text.find(video_marker, cursor)
            if video_at == -1:
                video_at = text_len
            boundary = image_at if image_at <= video_at else video_at

            # Plain text preceding the nearest marker (or the whole remainder)
            self._add_text(text[cursor:boundary], outputs)
            if boundary == text_len:
                break

            if boundary == image_at:
                outputs["pic_cnt"] += 1
                self._add_image(images[next_image], outputs)
                next_image += 1
                cursor = boundary + len(image_marker)
                continue

            # Otherwise the nearest marker is a video marker
            clip = videos[next_video]
            if isinstance(clip, dict):
                source, options = clip["video"], clip
            else:
                source, options = clip, {}
            frames, meta = self._load_and_process_video(source, options)

            outputs["video_cnt"] += 1
            self._add_video(frames, meta, outputs)
            next_video += 1
            cursor = boundary + len(video_marker)

        return self._pack_outputs(outputs)

    def request2ids(
        self, request: Dict[str, Any], tgts: List[str] = None
    ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
        """
        Convert chat request with multimodal messages into model inputs.

        The messages are normalized with ``parse_chat_messages``, rendered and
        tokenized through ``apply_chat_template``, and the resulting token
        sequence is split at every ``<|vision_start|>`` token. After each such
        token the next image/video item collected from the messages is
        expanded into visual tokens.

        Args:
            request: Dictionary containing:
                - messages: List of chat messages with text/image/video content
                - request_id: Unique identifier for logging
              ``request["messages"]`` is temporarily replaced by the parsed
              messages while the chat template runs and is always restored,
              including when templating fails. ``apply_chat_template`` also
              writes ``request["text_after_process"]``.
            tgts: Optional target sequences (currently not read by this method)

        Returns:
            Dict with same structure as text2ids() output

        Raises:
            AssertionError: If a message role is not a key of ``role_prefixes``
            ValueError: If the chat template yields no tokens
        """

        outputs = {
            "input_ids": [],
            "token_type_ids": [],
            "position_ids": [],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "labels": [],
            "cur_position": 0,
            "pic_cnt": 0,
            "video_cnt": 0,
        }

        # Parse and validate chat messages
        messages = parse_chat_messages(request.get("messages"))
        image_message_list = []  # Visual content items, in message order

        for msg in messages:
            role = msg.get("role")
            assert role in self.role_prefixes, f"Unsupported role: {role}"

            # Normalize content to list format
            content_items = msg.get("content")
            if not isinstance(content_items, list):
                content_items = [content_items]

            # Collect all visual content items
            image_message_list.extend(
                item
                for item in content_items
                if isinstance(item, dict) and item.get("type") in ("image", "video")
            )

        # Template the parsed messages, then put the caller's originals back
        # even if templating raises, so the request is never left mutated.
        raw_messages = request["messages"]
        request["messages"] = messages
        try:
            prompt_token_ids = self.apply_chat_template(request)
        finally:
            request["messages"] = raw_messages

        if len(prompt_token_ids) == 0:
            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

        vision_start_index = 0
        vision_message_index = 0
        for i, token_id in enumerate(prompt_token_ids):
            if token_id != self.vision_start_id:
                continue

            # Flush the text up to and including the vision-start token
            self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs)
            vision_start_index = i + 1

            # Each vision-start token consumes exactly one visual item. Advance
            # before the payload checks so an item without a payload does not
            # shift every later vision-start token onto the wrong item.
            image_message = image_message_list[vision_message_index]
            vision_message_index += 1

            if image_message["type"] == "image":
                img = image_message.get("image")
                if img is None:
                    continue
                outputs["pic_cnt"] += 1
                self._add_image(img, outputs)

            elif image_message["type"] == "video":
                video_bytes = image_message.get("video")
                if video_bytes is None:
                    continue
                frames, meta = self._load_and_process_video(video_bytes, image_message)

                outputs["video_cnt"] += 1
                self._add_video(frames, meta, outputs)

        # Remaining text after the last vision-start token
        self._add_text(prompt_token_ids[vision_start_index:], outputs)
        return self._pack_outputs(outputs)

    def _add_text(self, tokens, outputs: Dict) -> None:
        """
        Append a text segment to the accumulating model inputs.

        Args:
            tokens: Either a raw string (tokenized here) or a sequence of
                token IDs. Falsy input (empty string/list, None) is a no-op.
            outputs: Accumulator dictionary; ``input_ids``, ``token_type_ids``,
                ``position_ids`` and ``cur_position`` are updated in place

        Note:
            - Every appended token is marked with the "text" type flag
            - ``cur_position`` becomes one past the largest position ID emitted
        """
        if not tokens:
            return None

        token_ids = tokens
        if isinstance(token_ids, str):
            # Raw text: split into tokens, then map to vocabulary IDs
            pieces = self.tokenizer.tokenize(token_ids)
            token_ids = self.tokenizer.convert_tokens_to_ids(pieces)

        count = len(token_ids)
        outputs["input_ids"].extend(token_ids)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * count)

        # Text positions continue from where the previous segment ended
        positions = self._compute_text_positions(outputs["cur_position"], count)
        outputs["position_ids"].append(positions)
        outputs["cur_position"] = positions.max() + 1

    def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
        """
        Build position IDs for a run of text tokens.

        All three rows are identical: ``start_pos, start_pos + 1, ...``.

        Args:
            start_pos: Position assigned to the first token
            num_tokens: How many consecutive positions to produce

        Returns:
            numpy.ndarray: Position IDs shaped (3, num_tokens)
        """
        offsets = np.arange(num_tokens)
        # Replicate the single row three times (as a view), then shift;
        # the addition materializes a fresh array.
        replicated = np.broadcast_to(offsets[np.newaxis, :], (3, num_tokens))
        return replicated + start_pos

    def _add_image(self, img, outputs: Dict) -> None:
        """
        Append one image to the accumulating model inputs.

        Args:
            img: Image object; it is converted with ``img.convert("RGB")``
                before preprocessing (e.g. a PIL Image)
            outputs: Accumulator dictionary, updated in place

        Note:
            - The number of placeholder tokens is the product of the
              processor's ``grid_thw`` divided by ``merge_size ** 2``
            - Tokens are marked with the "image" type flag and a single 0 is
              appended to ``image_type_ids``
            - Position IDs are computed with a temporal step of 0
        """
        processed = self.image_processor.preprocess(images=[img.convert("RGB")])
        thw = processed["grid_thw"]
        token_count = thw.prod() // self.image_processor.merge_size**2
        grid_thw = thw.tolist()

        # Placeholder tokens standing in for the image patches
        outputs["input_ids"].extend([self.image_token_id] * token_count)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * token_count)

        # Visual payload and its metadata
        outputs["images"].append(processed["pixel_values"])
        outputs["grid_thw"].append(grid_thw)
        outputs["image_type_ids"].append(0)

        grid_t, grid_h, grid_w = grid_thw
        positions = self._compute_vision_positions(outputs["cur_position"], grid_t, grid_h, grid_w, 0)
        outputs["position_ids"].append(positions)
        outputs["cur_position"] = positions.max() + 1

    def _add_video(self, frames, meta: Dict, outputs: Dict) -> None:
        """
        Append one video to the accumulating model inputs.

        Args:
            frames: Video frames, passed as ``images=`` to the image processor
            meta: Video metadata; ``meta["fps"]`` is required
            outputs: Accumulator dictionary, updated in place

        Note:
            - Tokens are marked with the "video" type flag
            - ``image_type_ids`` receives one ``1`` per temporal grid step
            - The temporal step passed to the position computation is
              ``temporal_conv_size / fps``
        """
        processed = self.image_processor.preprocess(images=frames)
        thw = processed["grid_thw"]
        token_count = thw.prod() // self.image_processor.merge_size**2
        grid_thw = thw.tolist()

        # Placeholder tokens standing in for the video patches
        outputs["input_ids"].extend([self.video_token_id] * token_count)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * token_count)

        # Visual payload and its metadata
        outputs["images"].append(processed["pixel_values"])
        outputs["grid_thw"].append(grid_thw)
        outputs["image_type_ids"].extend([1] * grid_thw[0])

        # Seconds of video covered by one temporal grid step
        seconds_per_step = self.temporal_conv_size / meta["fps"]
        grid_t, grid_h, grid_w = grid_thw
        positions = self._compute_vision_positions(
            outputs["cur_position"], grid_t, grid_h, grid_w, seconds_per_step
        )
        outputs["position_ids"].append(positions)
        outputs["cur_position"] = positions.max() + 1

    def _compute_vision_positions(
        self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
    ) -> np.ndarray:
        """
        Generate 3D position IDs for visual inputs.

        Args:
            start_pos: Base position in sequence, added to every entry
            t: Temporal patches (1 for images)
            h: Height in patches, before spatial merging
            w: Width in patches, before spatial merging
            second_per_grid_t: Time per temporal patch (0 for images)

        Returns:
            np.ndarray: Integer array shaped
            ``(3, t * (h // spatial_conv_size) * (w // spatial_conv_size))``
            whose rows are the temporal, height and width position IDs.

        Note:
            The temporal index is
            ``floor(frame_index * second_per_grid_t * tokens_per_second)``.
            The product is truncated only after scaling, so fractional
            ``second_per_grid_t`` values (fps above the temporal patch size)
            still yield increasing temporal positions instead of all zeros.
        """
        merged_h = h // self.spatial_conv_size
        merged_w = w // self.spatial_conv_size

        # Temporal axis: scale each frame index, then truncate to an integer
        frame_steps = np.arange(t).reshape(-1, 1)
        time_offsets = (frame_steps * second_per_grid_t * self.tokens_per_second).astype(np.int64)
        t_index = np.broadcast_to(time_offsets, (t, merged_h * merged_w)).flatten()

        # Height axis: row index repeated across columns and frames
        hn = np.arange(merged_h).reshape(1, -1, 1)
        h_index = np.broadcast_to(hn, (t, merged_h, merged_w)).flatten()

        # Width axis: column index repeated across rows and frames
        wn = np.arange(merged_w).reshape(1, 1, -1)
        w_index = np.broadcast_to(wn, (t, merged_h, merged_w)).flatten()

        position = np.stack([t_index, h_index, w_index]) + start_pos
        return position

    def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]:
        """
        Read a video and optionally resample its frames.

        Resampling happens only when ``item`` carries a non-None ``"fps"`` or
        ``"target_frames"`` entry; otherwise the frames and metadata returned
        by ``read_frames`` are passed through unchanged.

        Args:
            url: Video source handed to ``read_frames``
            item: Options dictionary; recognized keys are ``fps``,
                ``target_frames``, ``min_frames`` and ``max_frames`` (the last
                two fall back to the instance defaults)

        Returns:
            tuple: (frames, metadata). After resampling, the metadata has
            ``num_of_frame`` set to the sampled frame count and either
            ``fps``/``duration`` set from the requested fps, or ``fps``
            recomputed as sampled frame count over the existing ``duration``.
        """
        frames, meta = read_frames(url)

        requested_fps = item.get("fps", None)
        requested_count = item.get("target_frames", None)
        if requested_fps is None and requested_count is None:
            # No sampling requested: hand back what was read
            return frames, meta

        # Per-item bounds override the instance-wide defaults
        lower_bound = item.get("min_frames", self.min_frames)
        upper_bound = item.get("max_frames", self.max_frames)

        frames = sample_frames(
            video=frames,
            frame_factor=self.temporal_conv_size,
            min_frames=lower_bound,
            max_frames=upper_bound,
            metadata=meta,
            fps=requested_fps,
            num_frames=requested_count,
        )

        # Bring the metadata in line with the sampled frames
        sampled_count = frames.shape[0]
        meta["num_of_frame"] = sampled_count
        if requested_fps is None:
            # Frame count was requested: derive the effective fps from the duration
            meta["fps"] = sampled_count / meta["duration"]
        else:
            meta["fps"] = requested_fps
            meta["duration"] = sampled_count / requested_fps

        return frames, meta

    def apply_chat_template(self, request):
        """
        Apply chat template to convert messages into token sequence.

        The rendered prompt is stored unmodified in
        ``request["text_after_process"]``. Before tokenization, every
        occurrence of the image and video pad tokens is removed from the
        prompt text.

        Args:
            request: Dictionary containing chat messages under ``"messages"``;
                ``"add_generation_prompt"`` (default True) and
                ``"request_id"`` (used for logging) are optional

        Returns:
            List of token IDs

        Raises:
            ValueError: If model doesn't support chat templates
        """
        if self.tokenizer.chat_template is None:
            raise ValueError("This model does not support chat_template.")

        raw_prompt = self.tokenizer.apply_chat_template(
            request["messages"],
            tokenize=False,
            add_generation_prompt=request.get("add_generation_prompt", True),
        )
        request["text_after_process"] = raw_prompt

        # Strip the pad tokens from the text; request2ids splits the resulting
        # token sequence at vision-start tokens and inserts visual tokens itself.
        prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")

        tokens = self.tokenizer.tokenize(prompt_token_str)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # The previous format expression had a trailing comma, which logged
        # the request id as a one-element tuple, e.g. "('abc',)".
        data_processor_logger.info(
            f"req_id:{request.get('request_id', '')} prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
        )
        return token_ids
